Source code for hysop.operator.base.custom_symbolic_operator

# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from abc import ABCMeta
import sympy as sm
import numpy as np

from hysop.tools.numpywrappers import npw
from hysop.tools.htypes import (
    check_instance,
    to_tuple,
    InstanceOf,
    first_not_None,
    to_set,
)
from hysop.tools.decorators import debug
from hysop.tools.sympy_utils import get_derivative_variables, SetupExprI
from hysop.fields.continuous_field import Field
from hysop.fields.discrete_field import DiscreteField, DiscreteScalarFieldView
from hysop.fields.field_requirements import DiscreteFieldRequirements
from hysop.parameters.scalar_parameter import ScalarParameter
from hysop.topology.topology import TopologyView
from hysop.operator.directional.directional import DirectionalOperatorBase
from hysop.backend.device.codegen.base.utils import SortedDict

from hysop.core.memory.memory_request import MemoryRequest
from hysop.topology.cartesian_descriptor import CartesianTopologyDescriptors

from hysop.symbolic import (
    Dummy,
    time_symbol,
    space_symbols,
    dspace_symbols,
    local_indices_symbols,
    global_indices_symbols,
)
from hysop.symbolic.relational import Assignment, AugmentedAssignment
from hysop.symbolic.misc import (
    TimeIntegrate,
    ApplyStencil,
    CodeSection,
    MutexLock,
    MutexUnlock,
    Cast,
)
from hysop.symbolic.tmp import TmpScalar
from hysop.symbolic.array import SymbolicArray, SymbolicBuffer, IndexedBuffer
from hysop.symbolic.field import AppliedSymbolicField, SymbolicDiscreteField
from hysop.symbolic.parameter import SymbolicScalarParameter

from hysop.constants import (
    ComputeGranularity,
    SpaceDiscretization,
    TranspositionState,
    DirectionLabels,
    SymbolicExpressionKind,
)
from hysop.numerics.odesolvers.runge_kutta import (
    TimeIntegrator,
    ExplicitRungeKutta,
    Euler,
    RK2,
    RK3,
    RK4,
)
from hysop.numerics.interpolation.interpolation import (
    MultiScaleInterpolation,
    Interpolation,
)
from hysop.numerics.stencil.stencil_generator import (
    StencilGenerator,
    CenteredStencilGenerator,
    MPQ,
)

# Top-level expression types handled by the symbolic parser: anything that
# is not an instance of one of these is filtered out by
# SymbolicExpressionParser.check_expressions before the LHS checks run.
ValidExpressions = tuple([Assignment])


class ExprDiscretizationInfo:
    """
    Helper class to store information about discretized symbolic expressions.

    Tracks how many times objects are read and written, plus a mapping from
    parameter names to parameters:

    * objects of ``IndexedCounterTypes`` (discrete scalar field views) are
      keyed by ``obj.dfield`` and counted per component in an integer array
      of length ``obj.nb_components``;
    * objects of ``SimpleCounterTypes`` (symbolic arrays and buffers) are
      keyed by the object itself and counted with a plain integer.
    """

    SimpleCounterTypes = (
        SymbolicArray,
        SymbolicBuffer,
    )
    IndexedCounterTypes = (DiscreteScalarFieldView,)

    def __new__(cls, **kwds):
        return super().__new__(cls, **kwds)

    def __init__(self, **kwds):
        """
        Helper class to store information about discretized symbolic expressions.
        """
        super().__init__(**kwds)
        self.read_counter = SortedDict()
        self.write_counter = SortedDict()
        self.parameters = SortedDict()

    def _increment(self, counter, obj, index, count):
        """Add `count` accesses of `obj` to `counter` (component `index` for field views)."""
        check_instance(count, int)
        if isinstance(obj, self.IndexedCounterTypes):
            assert index is not None
            counter.setdefault(
                obj.dfield, npw.int_zeros(shape=(obj.nb_components,))
            )[index] += count
        elif isinstance(obj, self.SimpleCounterTypes):
            counter.setdefault(obj, 0)
            # honour `count` for arrays/buffers too (previously always +1)
            counter[obj] += count
        else:
            msg = f"Unsupported type {type(obj)}."
            raise TypeError(msg)

    def _merge(self, dst, src):
        """Accumulate the counts of counter `src` into counter `dst`.

        Per-component arrays are summed into freshly allocated (or already
        owned) arrays of `dst`, so `dst` never aliases arrays of `src`.
        """
        for obj, counts in src.items():
            if isinstance(obj, self.IndexedCounterTypes):
                counter = dst.setdefault(
                    obj, npw.int_zeros(shape=(obj.nb_components,))
                )
                counter += counts
            elif isinstance(obj, self.SimpleCounterTypes):
                dst.setdefault(obj, 0)
                dst[obj] += counts
            else:
                msg = f"Unsupported type {type(obj)}."
                raise TypeError(msg)

    def read(self, obj, index=None, count=1):
        """Record `count` reads of `obj` (component `index` for field views).

        Raises TypeError for unsupported object types.
        """
        self._increment(self.read_counter, obj, index, count)

    def write(self, obj, index=None, count=1):
        """Record `count` writes of `obj` (component `index` for field views).

        Raises TypeError for unsupported object types.
        """
        self._increment(self.write_counter, obj, index, count)

    def copy(self):
        """Return an independent copy of the counters and parameters."""
        edi = ExprDiscretizationInfo()
        self._merge(edi.read_counter, self.read_counter)
        self._merge(edi.write_counter, self.write_counter)
        edi.push_parameters(**self.parameters)
        return edi

    def update(self, other):
        """Add the counters and parameters of `other` to this instance."""
        check_instance(other, ExprDiscretizationInfo)
        self._merge(self.read_counter, other.read_counter)
        self._merge(self.write_counter, other.write_counter)
        self.push_parameters(**other.parameters)

    def push_parameters(self, *param, **kwd_params):
        """Register parameters, positionally (keyed by `p.name`) or by keyword."""
        self.parameters.update(**kwd_params)
        for p in param:
            # store the parameter itself, not the whole positional tuple
            self.parameters[p.name] = p

    def __iadd__(self, rhs):
        """Add `rhs` to every non-zero component count and to every simple count."""
        check_instance(rhs, (np.integer, int))
        rhs = int(rhs)
        for counter in (self.read_counter, self.write_counter):
            for obj, counts in counter.items():
                if isinstance(obj, self.IndexedCounterTypes):
                    counts[counts > 0] += rhs
                elif isinstance(obj, self.SimpleCounterTypes):
                    counter[obj] += rhs
                else:
                    msg = f"Unsupported type {type(obj)}."
                    raise TypeError(msg)
        return self

    def __imul__(self, rhs):
        """Multiply every count by `rhs`."""
        check_instance(rhs, (np.integer, int))
        rhs = int(rhs)
        for counter in (self.read_counter, self.write_counter):
            for obj, counts in counter.items():
                if isinstance(obj, self.IndexedCounterTypes):
                    counts[...] = rhs * counts
                elif isinstance(obj, self.SimpleCounterTypes):
                    counter[obj] *= rhs
                else:
                    msg = f"Unsupported type {type(obj)}."
                    raise TypeError(msg)
        return self

    def read_objects(self, types):
        """Return the set of read objects that are instances of `types`."""
        return {obj for obj in self.read_counter.keys() if isinstance(obj, types)}

    def written_objects(self, types):
        """Return the set of written objects that are instances of `types`."""
        # previously scanned read_counter by mistake
        return {obj for obj in self.write_counter.keys() if isinstance(obj, types)}

    @property
    def fields(self):
        """Discrete scalar field views that are read or written."""
        # set.update() returns None, hence the explicit union
        return self.read_objects(DiscreteScalarFieldView) | self.written_objects(
            DiscreteScalarFieldView
        )

    @property
    def arrays(self):
        """Symbolic arrays that are read or written."""
        return self.read_objects(SymbolicArray) | self.written_objects(SymbolicArray)

    @property
    def buffers(self):
        """Symbolic buffers that are read or written."""
        return self.read_objects(SymbolicBuffer) | self.written_objects(
            SymbolicBuffer
        )
class SymbolicExpressionInfo:
    """Helper class store information about parsed symbolic expressions."""

    def __new__(
        cls, name, exprs, dt=None, dt_coeff=None, compute_resolution=None, **kwds
    ):
        return super().__new__(cls, **kwds)

    def __init__(
        self, name, exprs, dt=None, dt_coeff=None, compute_resolution=None, **kwds
    ):
        """
        Parameters
        ----------
        name: operator name (used in error messages).
        exprs: tuple of symbolic expressions.
        dt, dt_coeff: required (ScalarParameter, float) when the expressions
            are of kind TIME_INTEGRATE, otherwise ignored.
        compute_resolution: optional int or tuple of ints, fixes the
            dimension and the compute resolution.

        Raises RuntimeError when TIME_INTEGRATE expressions lack dt/dt_coeff.
        """
        super().__init__(**kwds)
        self.name = name
        self.exprs = exprs
        self.kind = self.check_expressions(exprs)

        # continuous part
        self.domain = None
        self.input_arrays = SortedDict()
        self.output_arrays = SortedDict()
        self.input_buffers = SortedDict()
        self.output_buffers = SortedDict()
        self.input_fields = SortedDict()
        self.output_fields = SortedDict()
        self.input_params = SortedDict()
        self.output_params = SortedDict()
        self.scalars = SortedDict()
        self.is_volatile = set()
        self.direction = None
        self.has_direction = None

        if self.kind is SymbolicExpressionKind.TIME_INTEGRATE:
            if not isinstance(dt, ScalarParameter) or not isinstance(dt_coeff, float):
                msg = "Symbolic expressions of kind TIME_INTEGRATE require two extra parameters:"
                msg += "\n *dt: ScalarParameter, got dt={}."
                msg += "\n *dt_coeff: float, got dt_coeff={}."
                msg += "\n Please give simulation timestep as an input parameter of operator {}."
                msg = msg.format(dt, dt_coeff, name)
                raise RuntimeError(msg)
            self.input_params[dt.name] = dt
            self.dt = dt
            self.dt_coeff = dt_coeff

        # field requirements part
        self.min_ghosts = None
        self.min_ghosts_per_components = None

        # discrete part
        self.dexprs = None
        self.input_dfields = None
        self.output_dfields = None
        self.inout_dfields = None
        self.discretization_info = None
        self.stencils = None
        self.tmp_vars = None

        if compute_resolution is None:
            self.compute_resolution = None
            self._dim = None
        else:
            compute_resolution = to_tuple(compute_resolution)
            check_instance(compute_resolution, tuple, values=int)
            self._dim = len(compute_resolution)
            self.compute_resolution = compute_resolution

    def _is_discretized(self):
        """Return true if the SymbolicExpressionInfo was discretized."""
        return self.dexprs is not None

    def _get_dim(self):
        """Shortcut to domain dimension."""
        assert self._dim is not None
        return self._dim

    def _get_fields(self):
        """Return input and output fields."""
        fields = {k: v for (k, v) in self.input_fields.items()}
        fields.update(self.output_fields)
        return fields

    def _get_params(self):
        """Return input and output parameters."""
        params = {k: v for (k, v) in self.input_params.items()}
        params.update(self.output_params)
        return params

    @property
    def max_granularity(self):
        return self.dim - 1

    dim = property(_get_dim)
    fields = property(_get_fields)
    params = property(_get_params)
    is_discretized = property(_is_discretized)

    def check_expressions(self, exprs):
        """Return the SymbolicExpressionKind of `exprs` (delegates to the parser)."""
        return SymbolicExpressionParser.check_expressions(exprs)

    def check_field(self, field):
        """
        Check if given continuous field is compatible with previously
        parsed fields and arrays.
        """
        check_instance(field, Field)
        if self.domain is None:
            self.domain = field.domain
            if self._dim is None:
                self._dim = field.domain.dim
            elif self._dim != field.domain.dim:
                msg = "Dimension mismatch between field domain dimension and array dimension, "
                msg += f"got {self._dim} and {field.domain.dim}."
                raise ValueError(msg)
        elif field.domain is not self.domain:
            msg = "Domain mismatch for field {}:\n{}\nReference domain was:\n{}."
            msg = msg.format(field.name, field.domain, self.domain)
            raise ValueError(msg)

    def check_array(self, array):
        """
        Check if given symbolic array is compatible with previously
        parsed fields and arrays.
        """
        check_instance(array, SymbolicArray)
        dim = array.dim
        if self.domain is not None:
            if self.domain.dim != dim:
                msg = "Dimension mismatch between field domain dimension and array dimension, "
                msg += f"got {self.domain.dim} and {dim}."
                raise ValueError(msg)
        elif self._dim is not None:
            if self._dim != dim:
                msg = "Dimension mismatch between arrays, got {} and {}."
                msg = msg.format(self._dim, dim)
                raise ValueError(msg)
        else:
            self._dim = dim

    def extract_obj_requirements(self):
        """Compute and store field/array requirements, then fix the direction."""
        (field_requirements, array_requirements) = (
            SymbolicExpressionParser.extract_obj_requirements(self)
        )
        self.field_requirements = field_requirements
        self.array_requirements = array_requirements
        self.has_direction = self.direction is not None
        self.direction = first_not_None(self.direction, 0)

    def discretize_expressions(
        self, input_dfields, output_dfields, force_symbolic_axes
    ):
        """Select the discrete fields used by the expressions and discretize them.

        Raises RuntimeError when no discrete field is available, or when the
        discrete fields disagree on their transposition state axes.
        """
        check_instance(input_dfields, dict, keys=Field, values=DiscreteScalarFieldView)
        check_instance(output_dfields, dict, keys=Field, values=DiscreteScalarFieldView)
        assert len(set(self.input_fields.keys()) - set(input_dfields.keys())) == 0
        assert len(set(self.output_fields.keys()) - set(output_dfields.keys())) == 0
        self.input_dfields = {
            k: v for (k, v) in input_dfields.items() if (k in self.input_fields)
        }
        self.output_dfields = {
            k: v for (k, v) in output_dfields.items() if (k in self.output_fields)
        }
        self.inout_dfields = {
            k: v
            for (k, v) in self.output_dfields.items()
            if (
                (k in self.input_dfields)
                and (self.input_dfields[k].dfield is v.dfield)
            )
        }
        self.stencils = SortedDict()

        dfields = tuple(input_dfields.values()) + tuple(output_dfields.values())
        if force_symbolic_axes is not None:
            if isinstance(force_symbolic_axes, tuple):
                axes = force_symbolic_axes
            else:
                axes = None
        elif dfields:
            axes = dfields[0].tstate.axes
            for dfield in dfields:
                if dfield.tstate.axes != axes:
                    msg = "Discrete field {} has a topology state axes mismatch {} "
                    msg += "with reference axes {}."
                    msg = msg.format(dfield.name, dfield.tstate.axes, axes)
                    raise RuntimeError(msg)
        else:
            msg = "No discrete fields found in custom symbolic operator."
            raise RuntimeError(msg)
        self.axes = axes

        SymbolicExpressionParser.discretize_expressions(self)
        self.check_dfield_sizes()

    def setup_expressions(self, work):
        SymbolicExpressionParser.setup_expressions(self, work)

    def check_dfield_sizes(self):
        """Check that all discrete fields share the same compute resolution."""
        dfields = set(self.input_dfields.values()).union(self.output_dfields.values())
        if len(dfields) > 0:
            dfield0 = next(iter(dfields))
            compute_resolution = first_not_None(
                self.compute_resolution, dfield0.compute_resolution
            )
            for dfield in dfields:
                if (dfield.compute_resolution != compute_resolution).any():
                    msg = "Mismatching compute resolution {}::{} vs {}::{}."
                    msg = msg.format(
                        dfield.name,
                        dfield.compute_resolution,
                        dfield0.name,
                        dfield0.compute_resolution,
                    )
                    raise ValueError(msg)
            compute_resolution = tuple(compute_resolution)
            self.compute_resolution = compute_resolution

    def check_arrays(self):
        """Check that arrays are bound and that their shape matches the compute resolution."""
        compute_resolution = self.compute_resolution
        arrays = set(self.input_arrays.values()).union(self.output_arrays.values())
        for a in arrays:
            if not a.is_bound:
                msg = "FATAL ERROR: {}::{} has not been bound to any memory "
                msg += "prior to setup."
                msg = msg.format(type(a).__name__, a.name)
                raise RuntimeError(msg)
            dim = a.dim
            shape = a.shape
            if len(shape) != dim:
                msg = "FATAL ERROR: {}::{} array shape does not match array dimension."
                msg = msg.format(type(a).__name__, a.name)
                raise RuntimeError(msg)
            if compute_resolution is None:
                compute_resolution = shape
            elif not npw.array_equal(shape, compute_resolution):
                msg = "FATAL ERROR: {}::{} array shape {} does not comply with determined "
                msg += "compute resolution {}."
                msg = msg.format(type(a).__name__, a.name, shape, compute_resolution)
                raise RuntimeError(msg)
        if compute_resolution is None:
            msg = "FATAL ERROR: Something went wrong while determining compute_resolution."
            raise RuntimeError(msg)
        self.compute_resolution = tuple(compute_resolution)

    def check_buffers(self):
        """Check that every input/output buffer is bound to memory."""
        buffers = set(self.input_buffers.values()).union(self.output_buffers.values())
        for b in buffers:
            if not b.is_bound:
                msg = "FATAL ERROR: {}::{} has not been bound to any memory "
                msg += "prior to setup."
                msg = msg.format(type(b).__name__, b.name)
                raise RuntimeError(msg)

    def determine_direction(self, *variables):
        """Deduce the differentiation direction from `variables`.

        Raises ValueError when two different space directions are found.
        """
        assert self.domain is not None
        direction = self.direction
        for var in variables:
            if var in space_symbols:
                if direction is None:
                    direction = space_symbols.index(var)
                else:
                    xd = space_symbols[direction]
                    if xd != var:
                        msg = "Expression already contained a derivative with respect to {} (direction {})."
                        msg += "\nFound a new derivative direction which is not compatible with the current "
                        msg += "one."
                        msg += "\nCannot differentiate with respect to {}."
                        msg = msg.format(xd, direction, var)
                        raise ValueError(msg)
        self.direction = direction

    def __str__(self):
        msg = """
::SymbolicExpressionInfo::
  expression kind: {}
  continuous expressions:{}
  input_fields: {}
  output_fields: {}
  input_arrays: {}
  output_arrays: {}
  input_buffers: {}
  output_buffers: {}
  input_params: {}
  output_params: {}
  discretizations:{}
""".format(
            self.kind,
            "\n" + "\n".join(f" {i}/ {e}" for i, e in enumerate(self.exprs)),
            (
                ", ".join(f"{f.name}" for f in self.input_fields.keys())
                if self.input_fields
                else "none"
            ),
            (
                ", ".join(f"{f.name}" for f in self.output_fields.keys())
                if self.output_fields
                else "none"
            ),
            (
                ", ".join(f"{f}" for f in self.input_arrays.keys())
                if self.input_arrays
                else "none"
            ),
            (
                ", ".join(f"{f}" for f in self.output_arrays.keys())
                if self.output_arrays
                else "none"
            ),
            (
                ", ".join(f"{f}" for f in self.input_buffers.keys())
                if self.input_buffers
                else "none"
            ),
            (
                ", ".join(f"{f}" for f in self.output_buffers.keys())
                if self.output_buffers
                else "none"
            ),
            (
                ", ".join(f"{p}" for p in self.input_params.keys())
                if self.input_params
                else "none"
            ),
            (
                ", ".join(f"{p}" for p in self.output_params.keys())
                if self.output_params
                else "none"
            ),
            "\n"
            + "\n".join(
                " {}: {}".format(
                    f.name,
                    (d.short_description() if isinstance(d, TopologyView) else d),
                )
                for (f, d) in self.fields.items()
            ),
        )
        if self.min_ghosts:
            msg += """
  min_ghosts_per_components:{}
  min_ghosts:{}
""".format(
                "\n"
                + "\n".join(
                    " {}/ [{}]".format(f.name, ", ".join(str(x) for x in gpc))
                    for f, gpc in self.min_ghosts_per_components.items()
                ),
                "\n"
                + "\n".join(
                    " {}/ [{}]".format(f.name, ", ".join(str(x) for x in g))
                    for f, g in self.min_ghosts.items()
                ),
            )
        if self.is_discretized:
            msg += """
  discretized expressions:{}
  read_counter: {}
  write_counter: {}
""".format(
                "\n" + "\n".join(f" {i}/ {e}" for i, e in enumerate(self.dexprs)),
                (
                    ", ".join(
                        f"{f.name}{self.discretization_info.read_counter[f.dfield]}"
                        for f in self.input_dfields.values()
                    )
                    if self.input_dfields
                    else "none"
                ),
                (
                    ", ".join(
                        f"{f.name}{self.discretization_info.write_counter[f.dfield]}"
                        for f in self.output_dfields.values()
                    )
                    if self.output_dfields
                    else "none"
                ),
            )
        return msg

    # NOTE: a second, broken `_get_fields` (reading the nonexistent
    # `_input_fields` and returning the None result of set.update) used to be
    # defined here; it shadowed the working method above and was removed.

    @property
    def extracted_exprs(self):
        """Continuous expressions with CodeSection nesting flattened."""
        return SymbolicExpressionParser.extract_expressions(self.exprs)

    @property
    def extracted_dexprs(self):
        """Discrete expressions with CodeSection nesting flattened."""
        return SymbolicExpressionParser.extract_expressions(self.dexprs)
class SymbolicExpressionParser:
    """
    Helper class to parse symbolic expressions.

    The methods visible here are classmethods that walk sympy expression
    trees and record the fields, arrays, buffers, parameters and temporary
    scalars that are read or written into a SymbolicExpressionInfo instance.
    """
[docs] @classmethod def extract_expressions(cls, exprs): E = () for e in exprs: if isinstance(e, CodeSection): E += cls.extract_expressions(e.args) else: E += (e,) return E
[docs] @classmethod def check_expressions(cls, exprs): kind = None fields = SortedDict() arrays = SortedDict() exprs = tuple( filter( lambda e: isinstance(e, ValidExpressions), cls.extract_expressions(exprs), ) ) for expr in exprs: check_instance(expr, ValidExpressions) lhs = expr.args[0] field = None array = None if isinstance(lhs, TmpScalar): continue elif isinstance(lhs, IndexedBuffer): pass elif isinstance(lhs, (AppliedSymbolicField, SymbolicArray)): if kind is None: kind = SymbolicExpressionKind.AFFECT elif kind is not SymbolicExpressionKind.AFFECT: msg = "Symbolic expression kind was set to be {} but found " msg += " an expression that is of kind {}.\n expr: {}" msg = msg.format(kind, SymbolicExpressionKind.AFFECT, expr) raise ValueError(msg) if isinstance(lhs, AppliedSymbolicField): field = lhs else: array = lhs elif isinstance(lhs, sm.Derivative): _vars = get_derivative_variables(lhs) _t = time_symbol unique_vars = set(_vars) if isinstance(lhs.args[0], AppliedSymbolicField): field = lhs.args[0] elif isinstance(lhs.args[0], SymbolicArray): msg = "Assignment LHS cannot be a derivative of a SymbolicArray " msg += "because ghosts are not handled as symbolic Fields." raise TypeError(msg) else: msg = "Assignment LHS cannot be a derivative of a {}." msg = msg.format(type(lhs)) raise TypeError(msg) if (_t not in unique_vars) or len(unique_vars) != 1: msg = "Assignment LHS can only be a derivative of time {}, got {}." msg = msg.format(_t, ", ".join(str(x) for x in unique_vars)) raise TypeError(msg) if kind is None: kind = SymbolicExpressionKind.TIME_INTEGRATE elif kind is not SymbolicExpressionKind.TIME_INTEGRATE: msg = "Symbolic expression kind was set to be {} but found " msg += " an expression that is of kind {}.\n expr: {}" msg = msg.format(kind, SymbolicExpressionKind.TIME_INTEGRATE, expr) raise ValueError(msg) else: msg = f"Assignment LHS cannot be of type {type(lhs)}." 
raise TypeError(msg) if field is not None: assert isinstance(field, AppliedSymbolicField) index = field.index field = field.field key = (field, index) if key in fields: msg = "Field {} was already written by expression\n" msg += "{}\ncannot write it again in expression\n" msg += "{}\nFATAL ERROR: Invalid expressions." msg = msg.format(field.name, fields[key], expr) raise ValueError(msg) fields[key] = expr if array is not None: assert isinstance(array, SymbolicArray) key = array if key in arrays: msg = "Array {} was already written by expression\n" msg += "{}\ncannot write it again in expression\n" msg += "{}\nFATAL ERROR: Invalid expressions." msg = msg.format(array.name, arrays[key], expr) raise ValueError(msg) arrays[key] = expr if kind is None: kind = SymbolicExpressionKind.AFFECT return kind
[docs] @classmethod def parse(cls, name, variables, *exprs, **kwds): preferred_direction = first_not_None(kwds.pop("preferred_direction", None), 0) info = SymbolicExpressionInfo(name, exprs, **kwds) for expr in cls.extract_expressions(exprs): cls.parse_one(variables, info, expr) if info._dim is None: msg = "\n\nFATAL ERROR: Neither SymbolicFields nor SymbolicArrays were present in parsed " msg += "symbolic expressions and compute_resolution has not been specified." msg += "\nAt least one is needed to deduce the shape of the compute kernel." msg += "\n" msg += "\nExpressions were:" for i, e in enumerate(exprs): msg += f"\n {i:2>}/ {e}" msg += "\n" raise RuntimeError(msg) if info.direction is None: info.direction = preferred_direction return info
[docs] @classmethod def parse_one(cls, variables, info, expr): if isinstance(expr, Assignment): cls.parse_assignment(variables, info, *expr.args) else: try: cls.parse_subexpr(variables, info, expr) except: msg = "Failed to parse symbolic expression type {}." print() print(msg.format(type(expr))) print() raise
[docs] @classmethod def parse_assignment(cls, variables, info, lhs, rhs): if isinstance( lhs, (AppliedSymbolicField, SymbolicArray, IndexedBuffer, TmpScalar) ): cls.write(variables, info, lhs) cls.parse_subexpr(variables, info, rhs) if isinstance(lhs, IndexedBuffer): cls.parse_subexpr(lhs.index, info, rhs) elif isinstance(lhs, sm.Derivative): f = lhs.args[0] cls.read(variables, info, f) cls.write(variables, info, f) cls.parse_subexpr(variables, info, rhs) else: msg = "Unknown expression type {}.\n __mro__ = {}\nExpression is: {}\n" msg = msg.format(type(lhs), type(lhs).__mro__, lhs) raise NotImplementedError(msg)
[docs] @classmethod def parse_subexpr(cls, variables, info, expr): if isinstance(expr, npw.ndarray): assert expr.ndim == 0, expr expr = expr.tolist() if isinstance(expr, (str, int, float, complex, npw.number)): return elif isinstance( expr, (AppliedSymbolicField, SymbolicScalarParameter, SymbolicArray) ): cls.read(variables, info, expr) elif isinstance(expr, Cast): cls.parse_subexpr(variables, info, expr.expr) elif isinstance(expr, MutexLock): var = expr.mutexes cls.read(variables, info, var) cls.write(variables, info, var) info.is_volatile.add(var.name) elif isinstance(expr, MutexUnlock): var = expr.mutexes cls.write(variables, info, var) info.is_volatile.add(var.name) elif isinstance(expr, sm.Derivative): dvars = get_derivative_variables(expr) info.determine_direction(*dvars) cls.parse_subexpr(variables, info, expr.args[0]) elif isinstance(expr, (sm.Expr, sm.Rel)): for e in expr.args: cls.parse_subexpr(variables, info, e) else: msg = "Unknown expression type {}.\n __mro__ = {}\nExpression is: {}\n" msg = msg.format(type(expr), type(expr).__mro__, expr) raise NotImplementedError(msg)
[docs] @classmethod def write(cls, variables, info, var): if isinstance(var, TmpScalar): info.scalars[var.varname] = var elif isinstance(var, IndexedBuffer): cls.write(variables, info, var.indexed_object) elif isinstance(var, SymbolicArray): array = var info.check_array(array) if array.name not in info.output_arrays: info.output_arrays[array.name] = array else: assert info.output_arrays[array.name] is array elif isinstance(var, SymbolicBuffer): buf = var if buf.name not in info.output_buffers: info.output_buffers[buf.name] = buf else: assert info.output_buffers[buf.name] is buf elif isinstance(var, AppliedSymbolicField): field = var.field info.check_field(field) if field not in variables: msg = ( "Field {} is written but no discretization was given in variables." ) msg = msg.format(field.name) raise ValueError(msg) if field not in info.output_fields: info.output_fields[field] = variables[field] elif isinstance(var, SymbolicScalarParameter): param = var.parameter pname = param.name if param.const: msg = "FATAL ERROR: Cannot assign value to constant parameter {}." msg = msg.format(pname) raise ValueError(msg) elif (pname in info.output_params) and ( info.output_params[pname] is not param ): msg = "Incompatible parameter names {}." msg = msg.format(pname) raise ValueError(msg) info.output_params[pname] = param else: msg = "Unknown written variable type {}.\n __mro__ = {}\n" msg = msg.format(type(var), type(var).__mro__) raise NotImplementedError(msg)
[docs] @classmethod def read(cls, variables, info, var, offset=None): if isinstance(var, IndexedBuffer): cls.read(variables, info, var.indexed_object) elif isinstance(var, AppliedSymbolicField): field = var.field info.check_field(field) if field not in variables: msg = "Field {} is read but no discretization was given in variables." msg = msg.format(field.name) raise ValueError(msg) if field not in info.input_fields: info.input_fields[field] = variables[field] elif isinstance(var, SymbolicArray): array = var info.check_array(array) if array not in info.input_arrays: info.input_arrays[array.name] = array else: assert info.input_arrays[array.name] is array elif isinstance(var, SymbolicBuffer): buf = var if buf not in info.input_buffers: info.input_buffers[buf.name] = buf else: assert info.input_buffers[buf.name] is buf elif isinstance(var, SymbolicScalarParameter): param = var.parameter if param.name in info.input_params: assert ( info.input_params[param.name] is param ), "Incompatible parameter names." else: info.input_params[param.name] = param else: msg = "Unknown read variable type {}.\n __mro__ = {}\n" msg = msg.format(type(var), type(var).__mro__) raise NotImplementedError(msg)
[docs] @classmethod def extract_obj_requirements(cls, info): check_instance(info, SymbolicExpressionInfo) direction = info.direction min_ghosts_per_expr = SortedDict() array_requirements = SortedDict() field_requirements = SortedDict() updated_fields = SortedDict() updated_arrays = SortedDict() def check_tmp(i, expr, unknown_tmp): if unknown_tmp: msg = "\nError during temporary scalars expansion pass." msg += "\nSymbolic expressions were:" msg += "\n " + "\n ".join(str(e) for e in info.exprs) msg += "\nExpression {}/ {} use temporary scalars that have not been " msg += "defined yet:{}" msg = msg.format( i, expr, "\n *" + "\n *".join(tmp.name for tmp in unknown_tmp) ) raise ValueError(msg) # expand tmp scalars into expressions RHS wexprs = () oexprs = () tmp_map = SortedDict() for i, expr in enumerate(info.extracted_exprs): if isinstance(expr, Assignment): (lhs, rhs) = expr.args try: rhs_tmp = set( filter(lambda v: isinstance(v, TmpScalar), rhs.free_symbols) ) except AttributeError: rhs_tmp = set() unknown_tmp = rhs_tmp - set(tmp_map.keys()) check_tmp(i, expr, unknown_tmp) try: rhs = rhs.xreplace(tmp_map) except AttributeError: pass if isinstance(lhs, TmpScalar): tmp_map[lhs] = rhs else: wexprs += (expr.func(lhs, rhs),) else: args = expr.args args_tmp = set() _args = () for a in args: try: atmp = set( filter(lambda v: isinstance(v, TmpScalar), a.free_symbols) ) args_tmp.update(atmp) except AttributeError: pass unknown_tmp = args_tmp - set(tmp_map.keys()) check_tmp(i, expr, unknown_tmp) try: a = a.xreplace(tmp_map) except AttributeError: pass _args += (a,) expr = expr.func(*_args) oexprs += (expr,) for i, expr in enumerate(wexprs): obj_reqs = cls._extract_obj_requirements(info, expr) lhs = expr.args[0] if isinstance(lhs, sm.Derivative): lhs = lhs.args[0] if isinstance(lhs, AppliedSymbolicField): (field, index) = lhs.field, lhs.index updated_fields[i] = (field, index, f"{field.name}_{index}") elif isinstance(lhs, SymbolicArray): array = lhs updated_arrays[i] = (array, 
array.name) elif isinstance(lhs, IndexedBuffer): pass else: msg = f"Unsupported type {type(lhs)}." raise TypeError(msg) min_ghosts_expr_i = SortedDict() for obj, reqs in obj_reqs.items(): if isinstance(obj, tuple) and isinstance(obj[0], Field): if obj in field_requirements: field_requirements[obj].update_requirements(reqs) else: field_requirements[obj] = reqs k = f"{obj[0].name}_{obj[1]}" elif isinstance(obj, SymbolicArray): if obj in array_requirements: array_requirements[obj].update_requirements(reqs) else: array_requirements[obj] = reqs k = obj.name else: msg = f"Unsupported type {type(obj).__mro__}." raise TypeError(msg) v = reqs.min_ghosts[-direction - 1] min_ghosts_expr_i[k] = v min_ghosts_per_expr[i] = min_ghosts_expr_i lhs_fields = {v[2]: k for (k, v) in updated_fields.items()} lhs_arrays = {v[1]: k for (k, v) in updated_arrays.items()} lhs_objects = lhs_fields.copy() lhs_objects.update(lhs_arrays) rhs_fields = tuple(f"{k[0].name}_{k[1]}" for k in field_requirements.keys()) rhs_arrays = tuple(k.name for k in array_requirements.keys()) rhs_objects = rhs_fields + rhs_arrays ro_fields = set(rhs_fields).difference(lhs_fields) ro_arrays = set(rhs_arrays).difference(lhs_arrays) ro_objects = ro_fields.union(ro_arrays) nlhsfields = len(lhs_fields) nrhsfields = len(rhs_fields) nlhsarrays = len(lhs_arrays) nrhsarrays = len(rhs_arrays) nlhsobjects = nlhsfields + nlhsarrays nrhsobjects = nrhsfields + nrhsarrays # compute ghosts per integration steps if info.kind == SymbolicExpressionKind.TIME_INTEGRATE: nsteps = info.time_integrator.stages else: nsteps = 1 info.nsteps = nsteps info.nlhsfields = nlhsfields info.nrhsfields = nrhsfields info.nlhsarrays = nlhsarrays info.nrhsarrays = nrhsarrays info.nlhsobjects = nlhsobjects info.nrhsobjects = nrhsobjects # compute ghosts per expression matrix all_objects = lhs_objects.copy() for i, fname in enumerate(ro_objects, start=nlhsobjects): all_objects[fname] = i nobjects = len(all_objects) assert nobjects == len(ro_fields) + 
len(ro_arrays) + len(lhs_fields) + len( lhs_arrays ) info.nobjects = nobjects expr_ghost_map = npw.int_zeros(shape=(nlhsobjects, nobjects)) for fi_name, i in lhs_objects.items(): min_ghosts = min_ghosts_per_expr[i] for fj_name, min_ghost in min_ghosts.items(): assert fj_name in all_objects, fj_name j = all_objects[fj_name] expr_ghost_map[i, j] = min_ghost min_ghosts_per_step = npw.int_zeros(shape=(nsteps, nlhsobjects)) if nlhsobjects: G_f = expr_ghost_map[:, :nlhsobjects] min_ghosts_per_step[0, :] = npw.max(G_f, axis=0) for s in range(1, nsteps): min_ghosts_per_step[s] = npw.max( G_f + min_ghosts_per_step[s - 1][:, None], axis=0 ) min_ghosts_lhs = min_ghosts_per_step[nsteps - 1] if nsteps > 1: min_ghosts_rhs = npw.max( expr_ghost_map[:, nlhsobjects:] + min_ghosts_per_step[nsteps - 2][:, None], axis=0, ) else: min_ghosts_rhs = npw.max(expr_ghost_map[:, nlhsobjects:], axis=0) min_ghosts = min_ghosts_lhs.tolist() + min_ghosts_rhs.tolist() else: min_ghosts = npw.int_zeros(shape=(nobjects,)) lhs_objects = {v: k for (k, v) in lhs_objects.items()} lhs_objects = tuple(lhs_objects[i] for i in range(nlhsobjects)) all_objects = {v: k for (k, v) in all_objects.items()} all_objects = tuple(all_objects[i] for i in range(nobjects)) info.expr_ghost_map = expr_ghost_map info.min_ghosts_per_integration_step = min_ghosts_per_step info.min_ghosts_per_field_name = dict(zip(all_objects, min_ghosts)) info.lhs_object_names = lhs_objects info.rhs_object_names = rhs_objects info.all_object_names = all_objects return field_requirements, array_requirements
@classmethod
def _extract_obj_requirements(cls, info, expr):
    """
    Recursively collect the requirements of every object read by ``expr``.

    Parameters
    ----------
    info: SymbolicExpressionInfo
        Parsing state. ``info.direction`` is read and, when still None, set
        from the first spatial derivative encountered.
        ``info.space_discretization`` gives the centered stencil order.
    expr:
        Sub-expression to analyse: python/numpy/sympy numbers, 0-d numpy
        arrays, strings, Cast, SymbolicArray, AppliedSymbolicField,
        sympy Derivative, Assignment, or any generic sympy Expr/Rel.

    Returns
    -------
    dict-like
        Maps each read object, either a ``(field, index)`` tuple or a
        SymbolicArray, to its requirements. Constants and strings yield an
        empty dict.

    Raises
    ------
    ValueError
        Derivative with respect to an unsupported variable, or with respect
        to several variables at once when no direction is known yet.
    RuntimeError
        Derivative direction incompatible with ``info.direction``.
    NotImplementedError
        Unknown expression type.
    """
    if isinstance(expr, npw.ndarray):
        # Only 0-d arrays are accepted here; unwrap them to a python scalar.
        assert expr.ndim == 0
        expr = expr.tolist()
    if isinstance(
        expr, (int, sm.Integer, float, complex, sm.Rational, sm.Float, npw.number)
    ):
        return {}
    elif isinstance(expr, Cast):
        return cls._extract_obj_requirements(info, expr.expr)
    elif isinstance(expr, SymbolicArray):
        return {expr: expr.new_requirements()}
    elif isinstance(expr, AppliedSymbolicField):
        field = expr.field
        index = expr.index
        return {
            (field, index): DiscreteFieldRequirements(
                operator=None, variables=None, field=field, _register=False
            )
        }
    elif isinstance(expr, str):
        return {}
    elif isinstance(expr, sm.Derivative):
        dexpr = expr.args[0]
        dvars = get_derivative_variables(expr)
        unique_dvars = set(dvars)
        invalid_dvars = unique_dvars - set(space_symbols) - {time_symbol}
        if invalid_dvars:
            msg = "Cannot differentiate with respect to variable(s) {}."
            msg = msg.format(", ".join(str(x) for x in invalid_dvars))
            msg += "\nOnly allowed variables are: {}".format(
                ", ".join(str(x) for x in space_symbols)
            )
            raise ValueError(msg)

        direction = info.direction
        if direction is not None:
            # A direction was already fixed by a previous derivative: all
            # subsequent derivatives must be taken along the same axis.
            xd = space_symbols[direction]
            if unique_dvars - {xd}:
                msg = (
                    "Expression already contained a derivative with respect to {} "
                )
                msg += "(direction {}, {}-axis)."
                msg += "\nFound a new derivative direction which is not compatible "
                msg += "with the current one."
                msg += "\nCannot differentiate with respect to {}."
                msg = msg.format(
                    xd,
                    direction,
                    DirectionLabels[direction],
                    ", ".join(str(x) for x in (unique_dvars - {xd})),
                )
                raise RuntimeError(msg)
        else:
            if len(unique_dvars) > 1:
                msg = "Cannot differentiate on different variables at a time: {}"
                msg = msg.format(", ".join(str(x) for x in unique_dvars))
                raise ValueError(msg)
            xd = dvars[0]
            assert xd in space_symbols, xd
            direction = space_symbols.index(xd)
            # Side effect: the first derivative fixes the expression direction.
            info.direction = direction

        derivative = len(dvars)
        order = info.space_discretization
        assert order > 0, order
        assert order % 2 == 0, order

        # The ghost width required by the derivative is the half-width of the
        # exact centered stencil of the requested order.
        csg = CenteredStencilGenerator()
        csg.configure(dim=1, dtype=MPQ, derivative=derivative)
        stencil = csg.generate_exact_stencil(order=order)
        min_ghosts = max(stencil.L, stencil.R)

        obj_reqs = cls._extract_obj_requirements(info, dexpr)
        for obj, req in obj_reqs.items():
            req.min_ghosts[-1 - direction] += min_ghosts
        return obj_reqs
    elif isinstance(expr, Assignment):
        lhs, rhs = expr.args
        if isinstance(lhs, sm.Derivative):
            # Only first-order time derivatives are accepted on the lhs.
            assert len(lhs.args) == 2
            dvar = lhs.args[1]
            if dvar != time_symbol:
                # sympy >= 1.2 stores derivative variables as (variable, count)
                assert dvar[0] == time_symbol
                assert dvar[1] == 1
        # Only the rhs is analysed here; lhs requirements are not collected by
        # this method (presumably handled by the caller).
        return cls._extract_obj_requirements(info, rhs)
    elif isinstance(expr, (sm.Expr, sm.Rel)):
        # Generic node: merge the requirements of all arguments.
        obj_requirements = SortedDict()
        for e in expr.args:
            obj_reqs = cls._extract_obj_requirements(info, e)
            for obj, reqs in obj_reqs.items():
                if obj in obj_requirements:
                    obj_requirements[obj].update_requirements(reqs)
                else:
                    obj_requirements[obj] = reqs
        return obj_requirements
    else:
        msg = "Unknown expression type {}.\n __mro__ = {}\n"
        msg = msg.format(type(expr), type(expr).__mro__)
        raise NotImplementedError(msg)
@classmethod
def discretize_expressions(cls, info):
    """
    Discretize every expression stored in ``info.exprs``.

    Results are written back into ``info``:
    ``info.dexprs`` receives the discretized expressions, as a tuple in input
    order. ``info.discretization_info`` receives an ExprDiscretizationInfo
    merging the per-expression infos.
    """
    check_instance(info, SymbolicExpressionInfo)
    merged_info = ExprDiscretizationInfo()
    discretized = []
    for symbolic_expr in info.exprs:
        discrete_expr, expr_info = cls.discretize_one(info, symbolic_expr)
        discretized.append(discrete_expr)
        merged_info.update(expr_info)
    info.dexprs = tuple(discretized)
    info.discretization_info = merged_info
@classmethod
def discretize_one(cls, info, expr):
    """
    Discretize a single top-level expression.

    Delegates to ``discretize_subexpr``. Returns its
    ``(discrete_expression, discretization_info)`` pair unchanged.
    """
    result = cls.discretize_subexpr(info, expr)
    return result
@classmethod
def discretize_assignment(cls, info, expr):
    """
    Discretize an Assignment expression.

    The rhs is always discretized first. The lhs may be one of:

    * an AppliedSymbolicField, replaced by ``dfield.s[index]`` and recorded
      as written;
    * a SymbolicArray, recorded as written;
    * an IndexedBuffer, recorded as written, with its index discretized;
    * a TmpScalar, registered in ``info.scalars``;
    * a first-order time derivative ``d(field)/dt``. The assignment then
      becomes a ``TimeIntegrate`` node using ``info.time_integrator``, and
      the field is also recorded as read.

    Returns
    -------
    (discrete_expression, ExprDiscretizationInfo)

    Raises
    ------
    NotImplementedError
        Unsupported lhs type.
    """
    msg = "Unsupported, use Assignment instead of {}."
    msg = msg.format(type(expr).__name__)
    assert not isinstance(expr, AugmentedAssignment), msg

    lhs, rhs = expr.args
    rhs, di = cls.discretize_subexpr(info, rhs)

    if isinstance(
        lhs,
        (
            AppliedSymbolicField,
            SymbolicArray,
            IndexedBuffer,
            TmpScalar,
        ),
    ):
        func = expr.func
    elif isinstance(lhs, sm.Derivative):
        assert isinstance(lhs.args[0], AppliedSymbolicField)
        assert len(lhs.args) == 2
        dvar = lhs.args[1]
        if dvar != time_symbol:
            # sympy >= 1.2 stores derivative variables as (variable, count)
            assert dvar[0] == time_symbol
            assert dvar[1] == 1
        lhs = lhs.args[0]
        assert expr.func is Assignment

        def func(*args):
            # Time derivative on the lhs: integrate in time instead of assigning.
            return TimeIntegrate(info.time_integrator, *args)

        # The integrated field is also read (its previous value is needed).
        dfield = info.input_dfields[lhs.field]
        cls.read_discrete(info, lhs, dfield, di)
    else:
        msg = "Invalid symbolic assignment lhs type {}."
        msg = msg.format(type(lhs))
        raise NotImplementedError(msg)

    if isinstance(lhs, AppliedSymbolicField):
        field, index = lhs.field, lhs.index
        dfield = info.output_dfields[field]
        lhs = dfield.s[index]
        check_instance(lhs, SymbolicDiscreteField)
        cls.write_discrete(info, lhs, dfield, di)
    elif isinstance(lhs, IndexedBuffer):
        di.write(lhs.indexed_object)
        index, edi = cls.discretize_subexpr(info, lhs.index)
        di.update(edi)
        lhs = lhs.func(lhs.indexed_object, index)
    elif isinstance(lhs, TmpScalar):
        info.scalars[lhs.varname] = lhs
    else:
        di.write(lhs)

    new_expr = func(lhs, rhs)
    return new_expr, di
@classmethod
def write_discrete(cls, info, expr, dfield, di):
    """
    Record a write access to component ``expr.index`` of ``dfield`` in ``di``.

    The trailing ``1`` is forwarded as-is to ``di.write``. Its meaning is
    defined by ExprDiscretizationInfo.write.
    """
    component = expr.index
    di.write(dfield, component, 1)
@classmethod
def read_discrete(cls, info, expr, dfield, di):
    """
    Record a read access to component ``expr.index`` of ``dfield`` in ``di``.

    The trailing ``1`` is forwarded as-is to ``di.read``. Its meaning is
    defined by ExprDiscretizationInfo.read.
    """
    component = expr.index
    di.read(dfield, component, 1)
@classmethod
def discretize_subexpr(cls, info, expr):
    """
    Recursively discretize a (sub-)expression.

    Continuous objects are replaced by discrete counterparts:

    * applied symbolic fields become ``dfield.s[index]``;
    * spatial derivatives become ``ApplyStencil`` nodes.

    Every read and write access is recorded in the returned
    ExprDiscretizationInfo.

    Parameters
    ----------
    info: SymbolicExpressionInfo
        Parsing state. Attributes used: input_dfields, direction,
        space_discretization, axes and stencils.
    expr:
        Expression to discretize. Sequences (list, tuple, set, numpy
        arrays) are discretized element-wise.

    Returns
    -------
    (discrete_expression, ExprDiscretizationInfo)
        Exact list/tuple/set inputs keep their type. numpy arrays and
        sequence subclasses come back as lists.

    Raises
    ------
    NotImplementedError
        Unknown expression type.
    """
    di = ExprDiscretizationInfo()
    if isinstance(expr, (list, tuple, set, npw.ndarray)):
        texpr = type(expr)
        E = ()
        for e in expr:
            e, edi = cls.discretize_subexpr(info, e)
            di.update(edi)
            E += (e,)
        if texpr in (list, tuple, set):
            expr = texpr(E)
        else:
            expr = list(E)
        return expr, di
    elif isinstance(expr, (int, float, complex, npw.number)):
        return expr, di
    elif cls.should_transpose_expr(info, expr):
        expr = cls.transpose_expr(info, expr)
        return expr, di
    elif isinstance(expr, Assignment):
        return cls.discretize_assignment(info, expr)
    elif isinstance(expr, Cast):
        e, edi = cls.discretize_subexpr(info, expr.expr)
        di.update(edi)
        return expr.func(e, *expr.args[1:]), di
    elif isinstance(expr, MutexLock):
        # Locking both reads and writes the mutex array.
        var = expr.mutexes
        di.read(var)
        di.write(var)
        args, edi = cls.discretize_subexpr(info, expr.args[1:])
        expr = expr.func(var, *args)
        di.update(edi)
        return expr, di
    elif isinstance(expr, MutexUnlock):
        var = expr.mutexes
        di.write(var)
        args, edi = cls.discretize_subexpr(info, expr.args[1:])
        expr = expr.func(var, *args)
        di.update(edi)
        return expr, di
    elif isinstance(expr, TmpScalar):
        return expr, di
    elif isinstance(expr, str):
        return expr, di
    elif isinstance(expr, SymbolicScalarParameter):
        di.push_parameters(expr.parameter)
        return expr, di
    elif isinstance(expr, (SymbolicArray, SymbolicBuffer)):
        di.read(expr)
        return expr, di
    elif isinstance(expr, AppliedSymbolicField):
        index, field = expr.index, expr.field
        dfield = info.input_dfields[field]
        cls.read_discrete(info, expr, dfield, di)
        return dfield.s[index], di
    elif isinstance(expr, sm.Derivative):
        dexpr = expr.args[0]
        dvars = get_derivative_variables(expr)
        unique_dvars = set(dvars)
        invalid_dvars = unique_dvars - set(space_symbols)
        direction = info.direction
        xd = dvars[0]
        derivative = len(dvars)
        order = info.space_discretization
        # Only a single spatial variable matching info.direction is accepted.
        assert not invalid_dvars
        assert xd in space_symbols, xd
        assert xd == space_symbols[direction]
        assert len(unique_dvars) == 1
        assert order > 0, order
        assert order % 2 == 0, order

        csg = CenteredStencilGenerator()
        csg.configure(dim=1, dtype=MPQ, derivative=derivative)
        stencil = csg.generate_exact_stencil(order=order)

        dexpr, di = cls.discretize_subexpr(info, dexpr)
        # NOTE(review): relies on ExprDiscretizationInfo supporting ``+= int``
        # to account for the extra stencil points; confirm semantics there.
        di += stencil.non_zero_coefficients() - 1
        expr = ApplyStencil(dexpr, stencil)
        info.stencils[stencil] = di.copy()
        return expr, di
    elif isinstance(expr, (sm.Expr, sm.Rel)):
        new_args = ()
        for e in expr.args:
            arg, edi = cls.discretize_subexpr(info, e)
            di.update(edi)
            new_args += (arg,)
        if new_args:
            try:
                expr = expr.func(*new_args)
            except Exception:
                # Print a diagnostic and propagate the original error.
                msg = "Failed to build a {} from arguments {}."
                msg = msg.format(expr.func, new_args)
                print()
                print(msg)
                print()
                raise
        return expr, di
    else:
        msg = "Unknown expression type {}.\n __mro__ = {}\n"
        msg = msg.format(type(expr), type(expr).__mro__)
        raise NotImplementedError(msg)
@classmethod
def setup_expressions(cls, info, work):
    """Call ``setup_one`` with ``work`` on every discretized expression of ``info``."""
    check_instance(info, SymbolicExpressionInfo)
    discrete_exprs = info.dexprs
    for discrete_expr in discrete_exprs:
        cls.setup_one(discrete_expr, work)
@classmethod
def setup_one(cls, dexpr, work):
    """Call ``setup(work)`` on every SetupExprI atom found in ``dexpr``."""
    setup_atoms = dexpr.atoms(SetupExprI)
    for setup_atom in setup_atoms:
        setup_atom.setup(work)
@classmethod
def transposable_expressions(cls):
    """
    Return the groups of symbols that follow axis transpositions.

    Each group is ordered by axis. ``transpose_expr`` permutes a symbol
    within its own group.
    """
    groups = (
        space_symbols,
        dspace_symbols,
        local_indices_symbols,
        global_indices_symbols,
    )
    return groups
@classmethod
def should_transpose_expr(cls, info, expr):
    """Return True if ``expr`` belongs to one of the transposable symbol groups."""
    for candidates in cls.transposable_expressions():
        if expr in candidates:
            return True
    return False
@classmethod
def transpose_expr(cls, info, expr):
    """
    Map ``expr`` to its counterpart under the axis permutation ``info.axes``.

    ``expr`` must belong to one of the groups returned by
    ``transposable_expressions``. Within the first such group, truncated to
    ``len(info.axes)`` symbols, the symbol at position ``i`` is replaced by
    the symbol at position ``info.axes[i]``. When ``info.axes`` is None,
    ``expr`` is returned unchanged.
    """
    axes = info.axes
    if axes is None:
        return expr
    dim = len(axes)
    assert isinstance(axes, tuple)
    assert cls.should_transpose_expr(info, expr)
    assert len(set(axes)) == dim  # axes must be a permutation

    symbols = next(
        (group[:dim] for group in cls.transposable_expressions() if expr in group),
        None,
    )
    assert symbols is not None
    position = symbols.index(expr)
    return symbols[axes[position]]
class CustomSymbolicOperatorBase(DirectionalOperatorBase, metaclass=ABCMeta):
    """
    Common implementation interface for custom symbolic (code generated) operators.

    Expressions are parsed at construction time by SymbolicExpressionParser.
    Input/output fields and parameters are deduced from them, and the parsed
    information is exposed through the ``expr_info`` property.
    """

    __default_method = {
        ComputeGranularity: 0,
        SpaceDiscretization: 2,
        TimeIntegrator: Euler,
        MultiScaleInterpolation: Interpolation.LINEAR,
    }

    __available_methods = {
        ComputeGranularity: InstanceOf(int),
        SpaceDiscretization: InstanceOf(int),
        TimeIntegrator: InstanceOf(ExplicitRungeKutta),
        MultiScaleInterpolation: Interpolation.LINEAR,
    }

    @classmethod
    def default_method(cls):
        """Return the parent default methods updated with this class defaults."""
        dm = super().default_method()
        dm.update(cls.__default_method)
        return dm

    @classmethod
    def available_methods(cls):
        """Return the parent available methods updated with this class choices."""
        am = super().available_methods()
        am.update(cls.__available_methods)
        return am

    @debug
    def handle_method(self, method):
        """
        Consume this operator's method entries and store them in expr_info.

        After the parent class handled ``method``, this pops
        ComputeGranularity, SpaceDiscretization, TimeIntegrator and
        MultiScaleInterpolation from it. The compute granularity must lie in
        ``[0, expr_info.max_granularity]``. The space discretization order
        must be an even integer ``>= 2``.
        """
        super().handle_method(method)
        cr = method.pop(ComputeGranularity)
        space_discretization = method.pop(SpaceDiscretization)
        time_integrator = method.pop(TimeIntegrator)
        interpolation = method.pop(MultiScaleInterpolation)

        assert 0 <= cr <= self.expr_info.max_granularity, cr
        assert 2 <= space_discretization, space_discretization
        assert space_discretization % 2 == 0, space_discretization

        self._expr_info.compute_granularity = cr
        self._expr_info.time_integrator = time_integrator
        self._expr_info.interpolation = interpolation
        self._expr_info.space_discretization = space_discretization

    @debug
    def __new__(
        cls,
        name,
        exprs,
        variables,
        splitting_direction=None,
        splitting_dim=None,
        dt_coeff=None,
        dt=None,
        time=None,
        **kwds,
    ):
        # Fields and parameters are only known after expression parsing in
        # __init__, hence the None placeholders here.
        return super().__new__(
            cls,
            name=name,
            input_fields=None,
            output_fields=None,
            input_params=None,
            output_params=None,
            input_tensor_fields=None,
            output_tensor_fields=None,
            splitting_direction=splitting_direction,
            splitting_dim=splitting_dim,
            dt_coeff=dt_coeff,
            **kwds,
        )

    @debug
    def __init__(
        self,
        name,
        exprs,
        variables,
        splitting_direction=None,
        splitting_dim=None,
        dt_coeff=None,
        dt=None,
        time=None,
        **kwds,
    ):
        """
        Initialize a CustomSymbolicOperatorBase.
        Expressions are parsed and input/output vars are extracted.

        Parameters
        ----------
        exprs: array_like of valid hysop.symbolic.Expr
            Expressions that will generate code.
            Valid expressions are defined as
            hysop.operator.base.custom_symbolic_operator.ValidExpressions.
        variables: dict
            dictionary of fields as keys and topologies as values.
        splitting_direction: int
            Expected direction of derivatives in given expression.
        splitting_dim: int
            Only used in directional splittings.
        dt_coeff: float
            Only used in directional splittings.
        dt: ScalarParameter
            Only used for integration.
        time:
            Accepted for interface compatibility; not used by this class.
        kwds:
            Base class keyword arguments.

        Raises
        ------
        ValueError
            If only one of splitting_dim and dt_coeff is given, or if the
            parsed expression direction differs from splitting_direction.

        Notes
        -----
        All input and output fields and parameters are directly extracted
        from expression analysis.
        """
        check_instance(variables, dict, keys=Field, values=CartesianTopologyDescriptors)
        check_instance(exprs, tuple, values=ValidExpressions, minsize=1)
        check_instance(splitting_direction, int, allow_none=True)
        check_instance(splitting_dim, int, allow_none=True)
        check_instance(dt_coeff, float, allow_none=True)
        check_instance(dt, ScalarParameter, allow_none=True)

        if (splitting_dim is None) ^ (dt_coeff is None):
            msg = "splitting_dim and dt_coeff should be specified in the same time."
            raise ValueError(msg)
        dt_coeff = first_not_None(dt_coeff, 1.0)

        # Expand tensor fields to scalar fields
        scalar_variables = {
            sfield: topod
            for (tfield, topod) in variables.items()
            for sfield in tfield.fields
        }

        expr_info = SymbolicExpressionParser.parse(
            name,
            scalar_variables,
            *exprs,
            dt=dt,
            dt_coeff=dt_coeff,
            preferred_direction=splitting_direction,
        )

        # Check for a user/expression direction mismatch BEFORE overriding
        # splitting_direction, otherwise the check can never trigger.
        if (splitting_direction is not None) and (
            expr_info.direction != splitting_direction
        ):
            msg = "Direction mismatch, expression has derivative in direction {} but direction {} "
            msg += "has been specified."
            msg = msg.format(expr_info.direction, splitting_direction)
            raise ValueError(msg)
        splitting_direction = expr_info.direction
        splitting_dim = first_not_None(splitting_dim, expr_info.domain.dim)

        input_fields = expr_info.input_fields
        output_fields = expr_info.output_fields
        input_params = set(expr_info.input_params.values())
        output_params = set(expr_info.output_params.values())

        # A tensor field is an input (resp. output) only if all of its scalar
        # components are.
        input_tensor_fields = ()
        output_tensor_fields = ()
        for tfield in filter(lambda x: x.is_tensor, variables.keys()):
            if all((f in input_fields) for f in tfield.fields):
                input_tensor_fields += (tfield,)
            if all((f in output_fields) for f in tfield.fields):
                output_tensor_fields += (tfield,)

        super().__init__(
            name=name,
            input_fields=input_fields,
            output_fields=output_fields,
            input_params=input_params,
            output_params=output_params,
            input_tensor_fields=input_tensor_fields,
            output_tensor_fields=output_tensor_fields,
            splitting_direction=splitting_direction,
            splitting_dim=splitting_dim,
            dt_coeff=dt_coeff,
            **kwds,
        )
        self._expr_info = expr_info

    def _get_expr_info(self):
        """Get information about parsed symbolic expressions."""
        return self._expr_info

    expr_info = property(_get_expr_info)

    @debug
    def get_field_requirements(self):
        """
        Extract field requirements from the first expression parsing stage.

        Ghost requirements computed during expression analysis are merged
        into the parent class requirements, for each field component and
        each symbolic array. When the expressions have a direction, allowed
        axes are restricted to transposition states whose last axis equals
        ``dim - 1 - direction``.

        Also fills ``expr_info.min_ghosts`` and
        ``expr_info.min_ghosts_per_components``, both keyed by field or by
        symbolic array.
        """
        requirements = super().get_field_requirements()

        expr_info = self.expr_info
        expr_info.extract_obj_requirements()

        dim = expr_info.domain.dim
        field_reqs = expr_info.field_requirements
        array_reqs = expr_info.array_requirements
        direction = expr_info.direction
        has_direction = expr_info.has_direction

        if has_direction:
            assert 0 <= direction < dim
            axes = TranspositionState[dim].filter_axes(
                lambda axes: (axes[-1] == dim - 1 - direction)
            )
            axes = tuple(axes)

        # NOTE(review): ghost arrays are indexed with dim - 1 - direction,
        # presumably because they are stored in reversed axis order.
        min_ghosts_per_components = SortedDict()

        for fields, iter_requirements in (
            (self.input_fields, requirements.iter_input_requirements),
            (self.output_fields, requirements.iter_output_requirements),
        ):
            if not fields:
                continue
            for field, td, req in iter_requirements():
                min_ghosts = npw.int_zeros(shape=(field.nb_components, field.dim))
                if has_direction:
                    req.axes = axes
                for index in range(field.nb_components):
                    fname = f"{field.name}_{index}"
                    G = expr_info.min_ghosts_per_field_name.get(fname, 0)
                    if (field, index) in field_reqs:
                        fi_req = field_reqs[(field, index)]
                        if fi_req.axes:
                            if req.axes is not None:
                                assert set(fi_req.axes).intersection(req.axes)
                                req.axes = tuple(
                                    set(req.axes).intersection(fi_req.axes)
                                )
                            else:
                                req.axes = fi_req.axes
                        _min_ghosts = fi_req.min_ghosts.copy()
                        _max_ghosts = fi_req.max_ghosts.copy()
                        assert _min_ghosts[dim - 1 - direction] <= G
                        assert _max_ghosts[dim - 1 - direction] >= G
                        _min_ghosts[dim - 1 - direction] = G
                        req.min_ghosts = npw.maximum(_min_ghosts, req.min_ghosts)
                        req.max_ghosts = npw.minimum(_max_ghosts, req.max_ghosts)
                        min_ghosts[index] = _min_ghosts.copy()
                    else:
                        req.min_ghosts[dim - 1 - direction] = max(
                            G, req.min_ghosts[dim - 1 - direction]
                        )
                        min_ghosts[index][dim - 1 - direction] = G
                    assert req.min_ghosts[dim - 1 - direction] >= G
                    assert req.max_ghosts[dim - 1 - direction] >= G
                if field not in min_ghosts_per_components:
                    min_ghosts_per_components[field] = min_ghosts

        expr_info.min_ghosts = {
            k: npw.max(v, axis=0) for (k, v) in min_ghosts_per_components.items()
        }
        expr_info.min_ghosts_per_components = {
            field: gpc[:, -1 - direction]
            for (field, gpc) in min_ghosts_per_components.items()
        }

        # array_reqs is a mapping: iterate over (array, requirements) pairs
        # and store each array's values under its own key instead of
        # clobbering the whole per-component mapping.
        for array, reqs in array_reqs.items():
            expr_info.min_ghosts[array] = reqs.min_ghosts.copy()
            expr_info.min_ghosts_per_components[array] = reqs.min_ghosts[
                -1 - direction
            ]

        return requirements

    @debug
    def discretize(self, force_symbolic_axes=None):
        """Discretize variables and symbolic expressions (no-op if already done)."""
        if self.discretized:
            return
        super().discretize()
        self._expr_info.discretize_expressions(
            input_dfields=self.input_discrete_fields,
            output_dfields=self.output_discrete_fields,
            force_symbolic_axes=force_symbolic_axes,
        )

    @debug
    def setup(self, work):
        """
        Setup required work after checking symbolic arrays and buffers.

        Raises
        ------
        ValueError
            If work is None.
        """
        # Fail fast: setup always rejected a None work, so reject it before
        # any side effect instead of after the parent setup.
        if work is None:
            raise ValueError("work is None.")
        self._expr_info.check_arrays()
        self._expr_info.check_buffers()
        super().setup(work)